# Copyright 2023 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load(
    "//:build_defs.bzl",
    "xnnpack_cxx_library",
    "xnnpack_kleidiai_defines",
    "xnnpack_test_deps_for_library",
    "xnnpack_unit_test",
)
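# `xnnpack_select_if` is referenced only by the disabled
# `microkernel_lists_test` further down in this file; restore this load
# together with that test (b/381390736).
#load(
#    "//:build_params.bzl",
#    "xnnpack_select_if"
#)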

# Dependencies shared by the micro-kernel tests and tester libraries in this
# package: two local helpers (`:next_prime`, `:replicable_random_device`) plus
# a set of root-package (`//:`) targets. Targets that need more append to this
# list with `+` instead of modifying it.
MICROKERNEL_TEST_DEPS = [
    "//:buffer",
    ":next_prime",
    ":replicable_random_device",
    "//:aligned_allocator",
    "//:all_microkernels",
    "//:allocator",
    "//:common",
    "//:datatype",
    "//:fp16",
    "//:isa_checks",
    "//:math",
    "//:memory",
    "//:microkernels_h",
    "//:microparams_init",
    "//:microparams",
    "//:packing",
    "//:params",
    "//:quantization",
    "//:requantization",
    "//:xnnpack_h",
]

############################## Testing utilities ###############################

# Test-only library exporting `replicable_random_device.h`; it lists no `srcs`.
xnnpack_cxx_library(
    name = "replicable_random_device",
    testonly = True,
    hdrs = ["replicable_random_device.h"],
    deps = xnnpack_test_deps_for_library(),
)

# Test-only library built from `next_prime.cc` / `next_prime.h`; declares no
# dependencies of its own.
xnnpack_cxx_library(
    name = "next_prime",
    testonly = True,
    srcs = ["next_prime.cc"],
    hdrs = ["next_prime.h"],
)

# Test-only tester library linked by the gemm/igemm/gemminc/ppmm tests in this
# file. On top of the shared micro-kernel test deps it also depends on
# `//:XNNPACK` and `//:config_hdrs`.
xnnpack_cxx_library(
    name = "gemm_microkernel_tester",
    testonly = True,
    srcs = ["gemm-microkernel-tester.cc"],
    hdrs = ["gemm-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library() + [
        "//:XNNPACK",
        "//:config_hdrs",
    ],
)

# Test-only library built from `unary-ops.cc` / `unary-ops.h`. Besides the
# shared micro-kernel test deps it depends on `//:reference_ukernels`.
xnnpack_cxx_library(
    name = "unary_ops",
    testonly = True,
    srcs = ["unary-ops.cc"],
    hdrs = ["unary-ops.h"],
    deps = MICROKERNEL_TEST_DEPS + ["//:reference_ukernels"] + xnnpack_test_deps_for_library(),
)

# Test-only tester exporting `vunary-microkernel-tester.h` (no `srcs`); it
# depends on `:unary_ops` in addition to the shared test deps.
xnnpack_cxx_library(
    name = "vunary_microkernel_tester",
    testonly = True,
    hdrs = ["vunary-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library() + [":unary_ops"],
)

# Test-only tester library built from `vbinary-microkernel-tester.{cc,h}`.
xnnpack_cxx_library(
    name = "vbinary_microkernel_tester",
    testonly = True,
    srcs = ["vbinary-microkernel-tester.cc"],
    hdrs = ["vbinary-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

# Test-only tester library built from `dwconv-microkernel-tester.{cc,h}`; it
# additionally depends on `//:microkernel_utils`.
xnnpack_cxx_library(
    name = "dwconv_microkernel_tester",
    testonly = True,
    srcs = ["dwconv-microkernel-tester.cc"],
    hdrs = ["dwconv-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library() + [
        "//:microkernel_utils",
    ],
)

# Test-only tester exporting `rdsum-microkernel-tester.h` (no `srcs`); linked
# by the `*_rdsum*` tests in this file.
xnnpack_cxx_library(
    name = "rdsum_microkernel_tester",
    testonly = True,
    hdrs = ["rdsum-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

# Test-only tester library built from `packq-microkernel-tester.{cc,h}`; linked
# by `x8_packq_test`.
xnnpack_cxx_library(
    name = "packq_microkernel_tester",
    testonly = True,
    srcs = ["packq-microkernel-tester.cc"],
    hdrs = ["packq-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
)

####################### Unit tests for microkernel lists #######################
# TODO: b/381390736 - Reenable once fixed.
#sh_test(
#    name = "microkernel_lists_test",
#    size = "small",
#    srcs = ["microkernel_lists_test.sh"],
#    data = [
#        "//:cmake_microkernel_lists",
#        "//:generated_microkernel_lists",
#        "//gen:bzl_microkernel_lists",
#    ],
#    target_compatible_with = xnnpack_select_if(
#        "//build_config:linux",
#        [],
#        ["@platforms//:incompatible"],
#    ),
#)

######################### Unit tests for micro-kernels #########################

# One `<kernel>_test` target per entry below. The source file is the kernel
# name with underscores replaced by dashes (e.g. `f16_vabs` -> `f16-vabs.cc`),
# and every target links `:vunary_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":vunary_microkernel_tester"],
) for ukernel in [
    # f16 entries.
    "f16_vabs",
    "f16_vapproxgelu",
    "f16_vclamp",
    "f16_velu",
    "f16_vgelu",
    "f16_vhswish",
    "f16_vlrelu",
    "f16_vneg",
    "f16_vrndd",
    "f16_vrndne",
    "f16_vrndu",
    "f16_vrndz",
    "f16_vrsqrt",
    "f16_vsigmoid",
    "f16_vsqr",
    "f16_vsqrt",
    "f16_vtanh",
    # f32 entries.
    "f32_vabs",
    "f32_vapproxgelu",
    "f32_vclamp",
    "f32_velu",
    "f32_vexp",
    "f32_vgelu",
    "f32_vhswish",
    "f32_vlog",
    "f32_vlrelu",
    "f32_vneg",
    "f32_vrelu",
    "f32_vrndd",
    "f32_vrndne",
    "f32_vrndu",
    "f32_vrndz",
    "f32_vrsqrt",
    "f32_vsigmoid",
    "f32_vsin",
    "f32_vsqr",
    "f32_vsqrt",
    "f32_vtanh",
    # s8 / u8 / qs8 / qu8 entries.
    "s8_vclamp",
    "u8_vclamp",
    "qs8_vlrelu",
    "qu8_vlrelu",
    # `*_vcvt` entries.
    "f16_f32_vcvt",
    "f16_qs8_vcvt",
    "f16_qu8_vcvt",
    "f32_f16_vcvt",
    "f32_qs8_vcvt",
    "f32_qu8_vcvt",
    "qs8_f16_vcvt",
    "qs8_f32_vcvt",
    "qs8_vcvt",
    "qu8_vcvt",
    "qu8_f32_vcvt",
]]

# One `<kernel>_test` target per entry below, each compiled from
# `<kernel with dashes>.cc` and linked against `:vbinary_microkernel_tester`.
#
# NOTE(review): the f16 entries include `f16_vrsubc` but no `f16_vsubc`, while
# the f32 entries have both `f32_vrsubc` and `f32_vsubc` - confirm whether the
# omission is intentional.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":vbinary_microkernel_tester"],
) for ukernel in [
    # f16 entries.
    "f16_vadd",
    "f16_vaddc",
    "f16_vdiv",
    "f16_vdivc",
    "f16_vmax",
    "f16_vmaxc",
    "f16_vmin",
    "f16_vminc",
    "f16_vmul",
    "f16_vmulc",
    "f16_vprelu",
    "f16_vpreluc",
    "f16_vrpreluc",
    "f16_vrdivc",
    "f16_vrsubc",
    "f16_vsqrdiff",
    "f16_vsqrdiffc",
    "f16_vsub",
    # f32 entries.
    "f32_vadd",
    "f32_vaddc",
    "f32_vcopysign",
    "f32_vcopysignc",
    "f32_vdiv",
    "f32_vdivc",
    "f32_vmax",
    "f32_vmaxc",
    "f32_vmin",
    "f32_vminc",
    "f32_vmul",
    "f32_vmulc",
    "f32_vprelu",
    "f32_vpreluc",
    "f32_vrpreluc",
    "f32_vrcopysignc",
    "f32_vrdivc",
    "f32_vrsubc",
    "f32_vsqrdiff",
    "f32_vsqrdiffc",
    "f32_vsub",
    "f32_vsubc",
    # qs8 entries.
    "qs8_vadd_minmax",
    "qs8_vaddc_minmax",
    "qs8_vmul_minmax_fp32",
    "qs8_vmul_minmax_rndnu",
    "qs8_vmulc_minmax_fp32",
    "qs8_vmulc_minmax_rndnu",
    # qu8 entries.
    "qu8_vadd_minmax",
    "qu8_vaddc_minmax",
    "qu8_vmul_minmax_fp32",
    "qu8_vmul_minmax_rndnu",
    "qu8_vmulc_minmax_fp32",
    "qu8_vmulc_minmax_rndnu",
]]

# One `<kernel>_test` target per entry below. Each compiles its own
# `<kernel with dashes>.cc` together with the shared
# `reduce-microkernel-tester.h` header listed directly in `srcs`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [
        ukernel.replace("_", "-") + ".cc",
        "reduce-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "f16_rmax",
    "f16_rmin",
    "f16_rminmax",
    "f32_rmax",
    "f32_rmin",
    "f32_rminmax",
    "f32_rdmin",
    "f32_rdmax",
    "f16_rdmin",
    "f16_rdmax",
    "s8_rmin",
    "s8_rmax",
    "s8_rminmax",
    "s8_rdmin",
    "s8_rdmax",
    "u8_rmin",
    "u8_rmax",
    "u8_rdmin",
    "u8_rdmax",
    "u8_rminmax",
]]

# One `<kernel>_test` target per entry below; each compiles
# `<kernel with dashes>.cc` plus `rsum-microkernel-tester.h`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [
        ukernel.replace("_", "-") + ".cc",
        "rsum-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "f16_rsum",
    "f16_f32acc_rsum",
    "f32_rsum",
    "qs8_rsum",
    "qu8_rsum",
]]

# One `<kernel>_test` target per entry below; each compiles
# `<kernel with dashes>.cc` plus `ibilinear-microkernel-tester.h`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [
        ukernel.replace("_", "-") + ".cc",
        "ibilinear-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "f32_ibilinear",
    "f32_ibilinear_chw",
    "f16_ibilinear",
    "f16_ibilinear_chw",
    "s8_ibilinear",
    "u8_ibilinear",
]]

# One `<kernel>_test` target per `(kernel, shard count)` pair below. All use a
# "moderate" timeout and link `:dwconv_microkernel_tester`; the second tuple
# element is passed through as `shard_count`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    timeout = "moderate",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    shard_count = shards,
    deps = MICROKERNEL_TEST_DEPS + [":dwconv_microkernel_tester"],
) for (ukernel, shards) in [
    ("f16_dwconv_minmax", 5),
    ("f32_dwconv", 5),
    ("f32_dwconv_minmax", 5),
    ("qs8_qc8w_dwconv_minmax_fp32", 10),
    ("qs8_dwconv_minmax_fp32", 10),
    ("qs8_dwconv_minmax_rndnu", 1),
    ("qu8_dwconv_minmax_fp32", 5),
    ("qu8_dwconv_minmax_rndnu", 1),
]]

# Test compiled from `maxpool-minmax.cc` and its tester header; declared with
# `shard_count = 5`.
xnnpack_unit_test(
    name = "maxpool_minmax_test",
    srcs = ["maxpool-microkernel-tester.h", "maxpool-minmax.cc"],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS,
)

# bf16 GEMM tests: one source file each (`<kernel with dashes>.cc`), both
# linking `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "bf16_gemm_minmax",
    "bf16_f32_gemm_minmax",
]]

# Test compiled from `avgpool-minmax.cc`; declared with `shard_count = 5`.
xnnpack_unit_test(
    name = "avgpool_minmax_test",
    srcs = ["avgpool-minmax.cc"],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS,
)

# f16 GEMM / IGEMM tests: one source file each (`<kernel with dashes>.cc`),
# all linking `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "f16_f32acc_gemm_minmax",
    "f16_gemm_minmax",
    "f16_f32acc_igemm_minmax",
    "f16_igemm_minmax",
]]

# Tests that compile one `.cc` file next to its tester header and need only the
# shared micro-kernel test deps. Each row is `(target name, srcs)`; the `srcs`
# lists are passed through unchanged.
[xnnpack_unit_test(
    name = test_name,
    srcs = test_srcs,
    deps = MICROKERNEL_TEST_DEPS,
) for (test_name, test_srcs) in [
    (
        "f16_spmm_minmax_test",
        ["f16-spmm-minmax.cc", "spmm-microkernel-tester.h"],
    ),
    (
        "f16_vmulcaddc_minmax_test",
        ["f16-vmulcaddc-minmax.cc", "vmulcaddc-microkernel-tester.h"],
    ),
    (
        "f16_raddstoreexpminusmax_test",
        ["f16-raddstoreexpminusmax.cc", "raddstoreexpminusmax-microkernel-tester.h"],
    ),
    (
        "f32_argmaxpool_test",
        ["argmaxpool-microkernel-tester.h", "f32-argmaxpool.cc"],
    ),
]]

# f32 IGEMM tests. Each target compiles two source files, `<kernel>.cc` and
# `<kernel>-2.cc` (kernel name with dashes), and links
# `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + part + ".cc" for part in ["", "-2"]],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "f32_igemm",
    "f32_igemm_relu",
    "f32_igemm_minmax",
]]

# Tests whose `srcs` are a tester header followed by `<kernel with dashes>.cc`.
# Each row is `(kernel, tester header)`; they need only the shared deps.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [tester_hdr, ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS,
) for (ukernel, tester_hdr) in [
    ("f32_conv_hwc", "conv-hwc-microkernel-tester.h"),
    ("f16_conv_hwc2chw", "conv-hwc2chw-microkernel-tester.h"),
    ("f32_conv_hwc2chw", "conv-hwc2chw-microkernel-tester.h"),
    ("f16_dwconv2d_chw", "dwconv2d-microkernel-tester.h"),
    ("f32_dwconv2d_chw", "dwconv2d-microkernel-tester.h"),
]]

# f32 GEMM tests. Each target compiles two source files, `<kernel>.cc` and
# `<kernel>-2.cc` (kernel name with dashes), and links
# `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + part + ".cc" for part in ["", "-2"]],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "f32_gemm",
    "f32_gemm_relu",
    "f32_gemm_minmax",
]]

# Single-source GEMM test linking `:gemm_microkernel_tester`.
xnnpack_unit_test(
    name = "f32_gemm_goi_minmax_test",
    srcs = ["f32-gemm-goi-minmax.cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# Single-source GEMM test that is also compiled with the preprocessor defines
# returned by `xnnpack_kleidiai_defines()`.
xnnpack_unit_test(
    name = "pf32_gemm_minmax_test",
    srcs = ["pf32-gemm-minmax.cc"],
    defines = xnnpack_kleidiai_defines(),
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# f32 GEMM tests with qc8w / qc4w weights: one source file each
# (`<kernel with dashes>.cc`), all linking `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "f32_qc8w_gemm",
    "f32_qc8w_gemm_relu",
    "f32_qc4w_gemm_minmax",
    "f32_qc8w_gemm_minmax",
]]

# GEMMINC test split across two source files; links `:gemm_microkernel_tester`.
xnnpack_unit_test(
    name = "f32_gemminc_minmax_test",
    srcs = ["f32-gemminc-minmax.cc", "f32-gemminc-minmax-2.cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# PPMM test; it reuses `:gemm_microkernel_tester` as its tester library.
xnnpack_unit_test(
    name = "f32_ppmm_minmax_test",
    srcs = ["f32-ppmm-minmax.cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# Compiled from `f32-raddexpminusmax.cc` plus its tester header.
xnnpack_unit_test(
    name = "f32_raddexpminusmax_test",
    srcs = ["f32-raddexpminusmax.cc", "raddexpminusmax-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS,
)

# Compiled from `f32-raddextexp.cc` alone; no tester header is listed.
xnnpack_unit_test(
    name = "f32_raddextexp_test",
    srcs = ["f32-raddextexp.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

# Compiled from `f32-raddstoreexpminusmax.cc` plus the tester header that
# `f16_raddstoreexpminusmax_test` also lists.
xnnpack_unit_test(
    name = "f32_raddstoreexpminusmax_test",
    srcs = ["f32-raddstoreexpminusmax.cc", "raddstoreexpminusmax-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS,
)

# Floating-point rdsum tests: one source file each (`<kernel with dashes>.cc`),
# both linking `:rdsum_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
) for ukernel in [
    "f16_f32acc_rdsum",
    "f32_rdsum",
]]

# SPMM test split across four source files (`f32-spmm-minmax.cc` and the
# `-2` / `-3` / `-4` parts), followed by the SPMM tester header.
xnnpack_unit_test(
    name = "f32_spmm_minmax_test",
    srcs = [
        "f32-spmm-minmax" + part + ".cc"
        for part in ["", "-2", "-3", "-4"]
    ] + ["spmm-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS,
)

# vcmul tests: each compiles `<kernel with dashes>.cc` followed by the shared
# `vcmul-microkernel-tester.h` header.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [
        ukernel.replace("_", "-") + ".cc",
        "vcmul-microkernel-tester.h",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "f16_vcmul",
    "f32_vcmul",
]]

# Compiled from `f32-vmulcaddc-minmax.cc` plus the tester header that
# `f16_vmulcaddc_minmax_test` also lists.
xnnpack_unit_test(
    name = "f32_vmulcaddc_minmax_test",
    srcs = ["f32-vmulcaddc-minmax.cc", "vmulcaddc-microkernel-tester.h"],
    deps = MICROKERNEL_TEST_DEPS,
)

# vscale* tests: a single source file each (`<kernel with dashes>.cc`) and only
# the shared micro-kernel test deps.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "f32_vscaleexpminusmax",
    "f32_vscaleextexp",
]]

# qd8 GEMM tests whose sources are split across four files: `<kernel>.cc`,
# `<kernel>-2.cc`, `<kernel>-3.cc` and `<kernel>-4.cc` (kernel name with
# dashes). All use a "moderate" timeout and link `:gemm_microkernel_tester`.
# Each row is `(kernel, shard_count)`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    timeout = "moderate",
    srcs = [
        ukernel.replace("_", "-") + part + ".cc"
        for part in ["", "-2", "-3", "-4"]
    ],
    shard_count = shards,
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for (ukernel, shards) in [
    ("qd8_f16_qc8w_gemm_minmax", 5),
    ("qd8_f32_qc8w_gemm_minmax", 10),
    ("qd8_f32_qc4w_gemm_minmax", 5),
]]

# Same four-file layout, but declared without a `shard_count`, so it is kept
# as a separate call rather than passing an explicit value.
xnnpack_unit_test(
    name = "qd8_f16_qc4w_gemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qc4w-gemm-minmax" + part + ".cc"
        for part in ["", "-2", "-3", "-4"]
    ],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# qd8 GEMM tests with qb4w weights: one source file each, "moderate" timeout,
# both linking `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    timeout = "moderate",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "qd8_f16_qb4w_gemm_minmax",
    "qd8_f32_qb4w_gemm_minmax",
]]

# qp8 GEMM tests: one source file each, "moderate" timeout, compiled with the
# defines returned by `xnnpack_kleidiai_defines()` and linked against
# `:gemm_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    timeout = "moderate",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    defines = xnnpack_kleidiai_defines(),
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for ukernel in [
    "qp8_f32_qc4w_gemm_minmax",
    "qp8_f32_qc8w_gemm_minmax",
]]

# The four targets below differ in timeout, number of source parts and
# sharding, so each is declared on its own. Every `srcs` list expands to
# `<base>.cc`, `<base>-2.cc`, ... in that order.

# Three source parts, "moderate" timeout, `shard_count = 5`.
xnnpack_unit_test(
    name = "qs8_qc8w_gemm_minmax_fp32_test",
    timeout = "moderate",
    srcs = [
        "qs8-qc8w-gemm-minmax-fp32" + part + ".cc"
        for part in ["", "-2", "-3"]
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# Four source parts, "moderate" timeout, no `shard_count`.
xnnpack_unit_test(
    name = "qd8_f16_qc8w_igemm_minmax_test",
    timeout = "moderate",
    srcs = [
        "qd8-f16-qc8w-igemm-minmax" + part + ".cc"
        for part in ["", "-2", "-3", "-4"]
    ],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# Three source parts, no explicit timeout, `shard_count = 5`.
xnnpack_unit_test(
    name = "qd8_f32_qc8w_igemm_minmax_test",
    srcs = [
        "qd8-f32-qc8w-igemm-minmax" + part + ".cc"
        for part in ["", "-2", "-3"]
    ],
    shard_count = 5,
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# Three source parts, "long" timeout, `shard_count = 10`.
xnnpack_unit_test(
    name = "qs8_qc8w_igemm_minmax_fp32_test",
    timeout = "long",
    srcs = [
        "qs8-qc8w-igemm-minmax-fp32" + part + ".cc"
        for part in ["", "-2", "-3"]
    ],
    shard_count = 10,
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
)

# qs8 / qu8 rdsum tests: one source file each (`<kernel with dashes>.cc`),
# both linking `:rdsum_microkernel_tester`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS + [":rdsum_microkernel_tester"],
) for ukernel in [
    "qs8_rdsum_minmax_fp32",
    "qu8_rdsum",
]]

# qu8 GEMM / IGEMM tests, all linking `:gemm_microkernel_tester`. Each row is
# `(kernel, source-file suffixes)`: the sources are `<kernel with dashes>` +
# suffix + `.cc`, so `["", "-2"]` means a two-file test and `[""]` a
# single-file one.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [ukernel.replace("_", "-") + part + ".cc" for part in parts],
    deps = MICROKERNEL_TEST_DEPS + [":gemm_microkernel_tester"],
) for (ukernel, parts) in [
    ("qu8_gemm_minmax_fp32", ["", "-2"]),
    ("qu8_gemm_minmax_rndnu", ["", "-2"]),
    ("qu8_igemm_minmax_fp32", ["", "-2"]),
    ("qu8_gemm_minmax_rndnu16", [""]),
    ("qu8_igemm_minmax_rndnu", ["", "-2"]),
]]

# Compiled from `u8-lut32norm.cc` with the `lut-norm` tester header.
xnnpack_unit_test(
    name = "u8_lut32norm_test",
    srcs = ["lut-norm-microkernel-tester.h", "u8-lut32norm.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

# Compiled from `x8-lut.cc` with the `lut` tester header.
xnnpack_unit_test(
    name = "x8_lut_test",
    srcs = ["lut-microkernel-tester.h", "x8-lut.cc"],
    deps = MICROKERNEL_TEST_DEPS,
)

# Compiled from `x8-packq.cc`; its tester comes from the
# `:packq_microkernel_tester` library rather than a header in `srcs`.
xnnpack_unit_test(
    name = "x8_packq_test",
    srcs = ["x8-packq.cc"],
    deps = MICROKERNEL_TEST_DEPS + [":packq_microkernel_tester"],
)

# One `<kernel>_test` target per entry below; each compiles
# `packw-microkernel-tester.h` followed by `<kernel with dashes>.cc`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = [
        "packw-microkernel-tester.h",
        ukernel.replace("_", "-") + ".cc",
    ],
    deps = MICROKERNEL_TEST_DEPS,
) for ukernel in [
    "x8_packw",
    "qs8_packw",
    "qs8_qc4w_packw",
    "x16_packw",
    "x32_packw",
]]

# Tests that need only the shared micro-kernel test deps. Each row is
# `(kernel, tester headers)`: `srcs` is the header list (possibly empty)
# followed by `<kernel with dashes>.cc`.
[xnnpack_unit_test(
    name = ukernel + "_test",
    srcs = tester_hdrs + [ukernel.replace("_", "-") + ".cc"],
    deps = MICROKERNEL_TEST_DEPS,
) for (ukernel, tester_hdrs) in [
    ("x32_packx", ["pack-microkernel-tester.h"]),
    ("xN_transpose", []),
    ("x32_unpool", ["unpool-microkernel-tester.h"]),
    ("xx_fill", []),
    ("xx_pad", []),
]]

############################### Misc unit tests ###############################

# Test compiled from `buffer.cc`; it lists its deps explicitly instead of using
# MICROKERNEL_TEST_DEPS.
xnnpack_unit_test(
    name = "buffer_test",
    srcs = ["buffer.cc"],
    deps = [":replicable_random_device", "//:buffer", "//:math"],
)

# Test compiled from `weights-cache.cc`, with an explicit dependency list.
xnnpack_unit_test(
    name = "weights_cache_test",
    srcs = ["weights-cache.cc"],
    deps = ["//:XNNPACK", "//:cache", "//:common", "//:memory"],
)

# Test compiled from `mutex.cc`, with an explicit dependency list.
xnnpack_unit_test(
    name = "mutex_test",
    srcs = ["mutex.cc"],
    deps = [":replicable_random_device", "//:common", "//:mutex", "//:xnnpack_h"],
)

# Test compiled from `microkernel-utils.cc`, with an explicit dependency list.
xnnpack_unit_test(
    name = "microkernel_utils_test",
    srcs = ["microkernel-utils.cc"],
    deps = [":replicable_random_device", "//:math", "//:microkernel_utils", "//:params"],
)

# Test compiled from `packing.cc`; extends the shared micro-kernel test deps
# with `//:microkernel_utils` and `//:operator_utils`.
xnnpack_unit_test(
    name = "packing_test",
    srcs = ["packing.cc"],
    deps = MICROKERNEL_TEST_DEPS + ["//:microkernel_utils", "//:operator_utils"],
)

# Test compiled from `indirection.cc`, with an explicit dependency list.
xnnpack_unit_test(
    name = "indirection_test",
    srcs = ["indirection.cc"],
    deps = [
        "//:buffer",
        "//:indirection",
        "//:math",
        "//:operator_utils",
        "//:operators",
        "//:xnnpack_h",
    ],
)

# Test compiled from `build-identifier.cc`; its only listed dependency is
# `//:XNNPACK`.
xnnpack_unit_test(
    name = "build_identifier_test",
    srcs = ["build-identifier.cc"],
    deps = ["//:XNNPACK"],
)
